﻿using System;
using System.Linq;
using System.Linq.Expressions;
using System.Reflection;
using IQToolkit;
using IQToolkit.Data;

namespace IQToolkitContrib {
    /// <summary>
    /// Produces a save function that writes entities to an IQToolkit entity table using
    /// IQToolkit's InsertOrUpdate operation. When the entity has a generated primary key,
    /// the key value returned by the operation is assigned back onto the saved instance.
    /// </summary>
    /// <typeparam name="T">The entity type stored in the table.</typeparam>
    public class DbEntityInserterOrUpdater<T> : DbEntityInserterBase<T> {
        /// <summary>
        /// Creates an inserter/updater for the given table.
        /// </summary>
        /// <param name="entityTable">The table entities are written to.</param>
        /// <param name="provider">The entity provider; forwarded to the base class.</param>
        public DbEntityInserterOrUpdater(IEntityTable<T> entityTable, DbEntityProvider provider)
            : base(entityTable, provider) {
        }

        /// <summary>
        /// Builds the delegate that saves a single entity with insert-or-update semantics.
        /// </summary>
        /// <returns>
        /// A function taking one instance and returning an affected-row count. For entities with a
        /// generated primary key, the function assigns the resulting key to the instance and
        /// returns 1 when no exception was raised.
        /// </returns>
        /// <exception cref="InvalidOperationException">
        /// The entity type does not expose a property matching the primary key name.
        /// </exception>
        protected override Func<T, int> GetFunction() {
            if (this.primaryKey != null && this.IsGeneratedPrimaryKey()) {
                // Generic arguments for Updatable.InsertOrUpdate<T, S>: the entity type and the key (result) type.
                Type[] genericTypes = new Type[] { this.entityType, this.primaryKey.PropertyType };

                LambdaExpression resultSelector = this.GetExpression(genericTypes);
                MethodInfo insertOrUpdateMethod = this.GetMethodInfo(genericTypes);

                // Resolve the key property once, up front. Otherwise a missing property would only
                // surface after the database write had already happened.
                PropertyInfo keyProperty = this.entityType.GetProperty(this.primaryKey.Name);
                if (keyProperty == null) {
                    throw new InvalidOperationException(
                        "Entity type '" + this.entityType.FullName + "' has no property named '" +
                        this.primaryKey.Name + "' to receive the generated primary key.");
                }

                return instance => {
                    // Arguments: the table (extension-method receiver), the instance, a null third
                    // argument (no predicate supplied), and the result selector built above.
                    object id = insertOrUpdateMethod.Invoke(null, new object[] { this.entityTable, instance, null, resultSelector });

                    // Write the generated key back onto the instance.
                    keyProperty.SetValue(instance, id);

                    // Assuming that reaching this point without an exception means the write succeeded.
                    return 1;
                };
            }

            // No auto-generated key: use the table's regular insert-or-update, which reports the row count itself.
            return instance => this.entityTable.InsertOrUpdate(instance);
        }

        /// <summary>
        /// Locates IQToolkit's generic <c>Updatable.InsertOrUpdate&lt;T, S&gt;</c> overload and
        /// closes it over the supplied type arguments.
        /// </summary>
        /// <param name="genericTypes">The entity type followed by the result (primary key) type.</param>
        /// <returns>The constructed generic method, ready for reflective invocation.</returns>
        /// <exception cref="InvalidOperationException">The expected overload could not be found.</exception>
        protected override MethodInfo GetMethodInfo(Type[] genericTypes) {
            // Target: the InsertOrUpdate overload whose return type is the generic parameter S.
            // GetFunction invokes it with four arguments (collection, instance, null, resultSelector).
            // NOTE(review): expected IQToolkit signature is
            //   S InsertOrUpdate<T, S>(this IUpdatable<T>, T, Expression<Func<T, bool>>, Expression<Func<T, S>>)
            // confirm against the IQToolkit version in use.
            MethodInfo definition = typeof(IQToolkit.Updatable).GetMethods()
                                                               .FirstOrDefault(m => m.Name == "InsertOrUpdate"
                                                                                 && m.IsGenericMethodDefinition
                                                                                 && m.ReturnType.IsGenericParameter
                                                                                 && m.GetParameters().Length == 4);

            if (definition == null) {
                throw new InvalidOperationException(
                    "Could not locate the generic IQToolkit.Updatable.InsertOrUpdate<T, S> overload via reflection.");
            }

            return definition.MakeGenericMethod(genericTypes);
        }
    }
}